///
/// Copyright (C) 2015-2016, Dependable Systems Laboratory, EPFL
/// Copyright (C) 2015-2016, Cyberhaven
/// All rights reserved.
///
/// Licensed under the Cyberhaven Research License Agreement.
///

#include "CGCInterface.h"

#include <s2e/ConfigFile.h>
#include <s2e/S2E.h>
#include <s2e/S2EExecutor.h>

#include <s2e/Plugins/OSMonitors/Linux/DecreeMonitor.h>
#include <s2e/Plugins/OSMonitors/Support/ModuleExecutionDetector.h>
#include <s2e/Plugins/VulnerabilityAnalysis/DecreePovGenerator.h>
#include <s2e/Plugins/VulnerabilityAnalysis/PovGenerator.h>

#include <s2e/Plugins/Core/Events.h>

#include <algorithm>
#include <ctime>
#include <sstream>

#include <llvm/Support/FileSystem.h>
#include <llvm/Support/raw_os_ostream.h>
#include <llvm/Support/raw_ostream.h>

extern "C" {
#include <qbool.h>
#include <qstring.h>
}

namespace s2e {
namespace plugins {

// Registers the CGCInterface plugin with S2E. The trailing plugin names are presumably the
// plugins this one depends on (TODO confirm against the S2E_DEFINE_PLUGIN definition);
// initialize() looks each of them up with getPlugin<>().
// NOTE(review): StaticFunctionModels is also looked up in initialize() but is not listed
// here, so m_models may be null; onTimer() null-checks it before use.
S2E_DEFINE_PLUGIN(CGCInterface, "CGC interface plugin", "", "ModuleExecutionDetector", "DecreeMonitor",
                  "ProcessExecutionDetector", "PovGenerationPolicy", "DecreePovGenerator", "BasicBlockCoverage",
                  "ControlFlowGraph", "SeedSearcher", "CallSiteMonitor", "TranslationBlockCoverage");

void CGCInterface::initialize() {
    m_monitor = s2e()->getPlugin<DecreeMonitor>();
    m_detector = s2e()->getPlugin<ModuleExecutionDetector>();
    m_procDetector = s2e()->getPlugin<ProcessExecutionDetector>();
    m_povGenerator = s2e()->getPlugin<pov::DecreePovGenerator>();
    m_exploitGenerator = s2e()->getPlugin<PovGenerationPolicy>();
    m_coverage = s2e()->getPlugin<coverage::BasicBlockCoverage>();
    m_tbcoverage = s2e()->getPlugin<coverage::TranslationBlockCoverage>();
    m_cfg = s2e()->getPlugin<ControlFlowGraph>();
    m_csTracker = s2e()->getPlugin<CallSiteMonitor>();
    m_models = s2e()->getPlugin<models::StaticFunctionModels>();
    m_seedSearcher = s2e()->getPlugin<seeds::SeedSearcher>();

    ConfigFile *cfg = s2e()->getConfig();

    m_maxPovCount = cfg->getInt(getConfigKey() + ".maxPovCount", 5);
    m_disableSendingExtraDataToDB = cfg->getBool(getConfigKey() + ".disableSendingExtraDataToDB", false);

    /// XXX: need to make all config params consistent (camelcase or underscore)
    m_recordConstraints = cfg->getBool(getConfigKey() + ".recordConstraints", false);
    m_recordAllPaths = cfg->getBool(getConfigKey() + ".record_all_paths", false);

    s2e()->getCorePlugin()->onStateKill.connect(sigc::mem_fun(*this, &CGCInterface::onStateKill));
    s2e()->getCorePlugin()->onTimer.connect(sigc::mem_fun(*this, &CGCInterface::onTimer),
                                            fsigc::signal_base::HIGHEST_PRIORITY);

    m_povGenerator->onRandomInputFork.connect(sigc::mem_fun(*this, &CGCInterface::onRandomInputFork),
                                              fsigc::signal_base::HIGHEST_PRIORITY);

    m_monitor->onRandom.connect(sigc::mem_fun(*this, &CGCInterface::onRandom));

    m_exploitGenerator->onPovReady.connect(sigc::mem_fun(*this, &CGCInterface::onPovReady));

    m_cbStatsUpdateInterval = cfg->getInt(getConfigKey() + ".stats_update_interval", 10);
    m_cbStatsLastSent = 0;
    m_cbStatsChanged = false;

    uint64_t now = llvm::sys::TimeValue::now().seconds();
    m_timeOfLastCoverageReport = now;

    // How often to go through all states to report those that
    // cover new blocks. Normally coverage would get reported
    // when a path completes, but that might miss states that didn't
    // finish but have nevertheless new covered blocks.
    m_coverageTimeout = cfg->getInt(getConfigKey() + ".coverageTimeout", 60);
}

/// Handler for DecreeMonitor::onRandom. Marks the binary of process @p pid as having
/// called random, and flags the per-binary stats as changed the first time this happens.
/// Events for processes whose name cannot be resolved are ignored.
void CGCInterface::onRandom(S2EExecutionState *state, uint64_t pid, const std::vector<klee::ref<klee::Expr>> &data) {
    std::string processName;
    if (m_monitor->getProcessName(state, pid, processName)) {
        auto &stats = m_cbStats[processName];
        if (!stats.calledRandom) {
            stats.calledRandom = true;
            m_cbStatsChanged = true;
        }
    }
}

/// Handler for DecreePovGenerator::onRandomInputFork. Records the module-relative PC of the
/// forking instruction and marks the stats as changed when that PC had not been recorded yet.
/// Does nothing if the current PC cannot be translated into the module's native address space.
void CGCInterface::onRandomInputFork(S2EExecutionState *state, const ModuleDescriptor &module) {
    uint64_t nativePc;
    if (module.ToNativeBase(state->regs()->getPc(), nativePc)) {
        auto &branchPcs = m_cbStats[module.Name].randomBranchesPc;
        const bool isNewPc = branchPcs.find(nativePc) == branchPcs.end();
        branchPcs.insert(nativePc);
        if (isNewPc) {
            m_cbStatsChanged = true;
        }
    }
}

/// Periodically reports execution statistics through a QMP event.
///
/// Intermediate coverage is processed on every call. The statistics themselves are only sent
/// when g_s2e_state is set, monitor_ready() returns true, and at least m_cbStatsUpdateInterval
/// seconds have passed since the previous report. Each report contains global stats and the
/// path of a freshly generated call-site JSON file. Per-module stats (whether random was called,
/// and the PCs recorded by onRandomInputFork) are included only when they changed since the
/// last report.
void CGCInterface::onTimer() {
    // Makes call-site file names unique across reports.
    static unsigned timerIndex = 0;

    // Need to use real time, because onTimer may not be called
    // exactly once per second, and could be delayed for a long
    // time by blocking operations (e.g., constraint solver)

    // TODO: this should really be a parameter of the onTimer signal
    uint64_t curTime = llvm::sys::TimeValue::now().seconds();

    processIntermediateCoverage(curTime);

    if (!g_s2e_state || !monitor_ready()) {
        return;
    }

    if (curTime - m_cbStatsLastSent < m_cbStatsUpdateInterval) {
        return;
    }

    m_cbStatsLastSent = curTime;

    getDebugStream() << "Sending statistics\n";

    Events::PluginData data;
    data.push_back(std::make_pair("type", QOBJECT(qstring_from_str("stats"))));

    if (m_cbStatsChanged && !m_cbStats.empty()) {
        /* per-module stats */
        QDict *modules = qdict_new();
        // Iterate by reference: each entry owns a set of PCs that must not be copied
        // on every report.
        for (const auto &module : m_cbStats) {
            QDict *mdata = qdict_new();
            qdict_put_obj(mdata, "called_random", QOBJECT(qbool_from_int(module.second.calledRandom)));

            QList *pcs = qlist_new();
            for (auto pc : module.second.randomBranchesPc) {
                qlist_append_obj(pcs, QOBJECT(qint_from_int(pc)));
            }

            qdict_put_obj(mdata, "random_branches_pc", QOBJECT(pcs));
            qdict_put_obj(modules, module.first.c_str(), QOBJECT(mdata));
        }

        data.push_back(std::make_pair("stats", QOBJECT(modules)));
        m_cbStatsChanged = false;
    }

    /* global stats */
    QDict *globalStats = qdict_new();

    // This information allows us to know whether the cfg lua file was loaded properly
    unsigned bbcnt = m_cfg ? m_cfg->getBasicBlockCount() : 0;
    qdict_put_obj(globalStats, "cfg_bb_count", QOBJECT(qint_from_int(bbcnt)));

    // m_models may be null: StaticFunctionModels is not a declared dependency.
    unsigned mcnt = m_models ? m_models->getFunctionModelCount() : 0;
    qdict_put_obj(globalStats, "model_count", QOBJECT(qint_from_int(mcnt)));

    data.push_back(std::make_pair("global_stats", QOBJECT(globalStats)));

    // Call site information
    std::stringstream callSiteFileName;
    callSiteFileName << "calls-" << timerIndex << ".json";
    std::string callSitePath = s2e()->getOutputFilename(callSiteFileName.str());
    m_csTracker->generateJsonFile(callSitePath);
    data.push_back(std::make_pair("callsites_filename", QOBJECT(qstring_from_str(callSitePath.c_str()))));

    ++timerIndex;

    Events::emitQMPEvent(this, data);
}

/// Escapes @p text so it can be embedded between double quotes in a JSON document.
/// Double quotes, backslashes and all control characters below 0x20 are escaped.
/// Pretty-printed klee expressions may span several lines, so this matters.
static std::string EscapeConstraintForJson(const std::string &text) {
    static const char hexDigits[] = "0123456789abcdef";

    std::string escaped;
    escaped.reserve(text.size());

    for (char ch : text) {
        const unsigned char uc = static_cast<unsigned char>(ch);
        switch (ch) {
            case '"':
                escaped += "\\\"";
                break;
            case '\\':
                escaped += "\\\\";
                break;
            case '\n':
                escaped += "\\n";
                break;
            case '\r':
                escaped += "\\r";
                break;
            case '\t':
                escaped += "\\t";
                break;
            default:
                if (uc < 0x20) {
                    // Remaining control characters must use the \u00XX form.
                    escaped += "\\u00";
                    escaped += hexDigits[(uc >> 4) & 0xf];
                    escaped += hexDigits[uc & 0xf];
                } else {
                    escaped += ch;
                }
                break;
        }
    }

    return escaped;
}

/// Writes the path constraints of @p state to @p output as a JSON array of strings.
/// Each element is the printed form of one constraint expression, JSON-escaped.
void CGCInterface::constraintsToJson(S2EExecutionState *state, std::stringstream &output) {
    output << "[";

    bool first = true;
    for (const auto &constraint : state->constraints) {
        if (!first) {
            output << ",";
        }
        first = false;

        // Print into a scratch buffer first so the text can be escaped as a whole.
        std::stringstream exprText;
        exprText << constraint;
        output << "\"" << EscapeConstraintForJson(exprText.str()) << "\"";
    }

    output << "]";
}

/// Writes the path constraints of @p state to a uniquely named JSON file in the S2E
/// output directory.
///
/// @return the path of the file. The path is returned even if the file could not be
///         opened or written; a warning is logged in that case.
std::string CGCInterface::constraintsToJsonFile(S2EExecutionState *state) {
    // Ensure unique file names
    static unsigned index = 0;
    std::stringstream fileName;
    fileName << "constraints-" << state->getID() << "-" << index << ".json";
    index++;

    std::string path = s2e()->getOutputFilename(fileName.str());

    std::stringstream output;
    constraintsToJson(state, output);

    std::error_code error;
    llvm::raw_fd_ostream o(path.c_str(), error, llvm::sys::fs::F_None);

    if (error) {
        getWarningsStream() << "Unable to open " << path << " - " << error.message() << "\n";
        return path;
    }

    o << output.str() << "\n";
    o.close();

    // raw_fd_ostream aborts the process from its destructor if an I/O error is still
    // pending, so report the failure and clear it instead (e.g., on a full disk).
    if (o.has_error()) {
        getWarningsStream() << "Unable to write " << path << "\n";
        o.clear_error();
    }

    return path;
}

/// Reports a generated test case to the service through a QMP event. The service decides
/// what to do with it (verify, send to db, etc.).
///
/// Bulky artifacts (constraints, basic block and translation block coverage) are written to
/// files in the S2E output directory; only their paths travel in the event.
void CGCInterface::sendTestcase(S2EExecutionState *state, const std::string &xmlPovPath, const std::string &cPovPath,
                                TestCaseType tcType, const PovOptions &opt, const std::string &recipeName) {
    // sendTestcase may run several times for the same state. A per-call index keeps the
    // generated file names unique, so files are not overwritten before the service reads them.
    static unsigned testCaseIndex = 0;

    Events::PluginData data;
    data.push_back(std::make_pair("type", QOBJECT(qstring_from_str("testcase"))));

    const char *typeName = nullptr;
    switch (tcType) {
        case PovGenerationPolicy::POV:
            typeName = "pov";
            break;
        case PovGenerationPolicy::CRASH:
            typeName = "crash";
            break;
        case PovGenerationPolicy::END_OF_PATH:
            typeName = "end_of_path";
            break;
        case PovGenerationPolicy::PARTIAL_PATH:
            typeName = "partial_path";
            break;
    }
    if (typeName) {
        data.push_back(std::make_pair("testcase_type", QOBJECT(qstring_from_str(typeName))));
    }

    if (!recipeName.empty()) {
        data.push_back(std::make_pair("recipe_name", QOBJECT(qstring_from_str(recipeName.c_str()))));
    }

    // Constraint sets can be huge, so they go through a file rather than through QMP.
    if (m_recordConstraints) {
        std::string constraintsPath = constraintsToJsonFile(state);
        data.push_back(std::make_pair("constraints_filename", QOBJECT(qstring_from_str(constraintsPath.c_str()))));
    }

    // Builds "<prefix><state id>-<test case index>.json" inside the S2E output directory.
    auto outputPathFor = [&](const char *prefix) {
        std::stringstream name;
        name << prefix << state->getID() << "-" << testCaseIndex << ".json";
        return s2e()->getOutputFilename(name.str());
    };

    // Basic block coverage.
    // XXX: This might be deprecated. The fuzzer might not need accurate basic block info;
    // TB coverage might be just as good. On the other hand, bb coverage gives interesting
    // data for the dashboard.
    std::string bbCoveragePath = outputPathFor("coverage-");
    m_coverage->generateJsonCoverageFile(state, bbCoveragePath);
    data.push_back(std::make_pair("coverage_filename", QOBJECT(qstring_from_str(bbCoveragePath.c_str()))));

    // Translation block coverage.
    // Useful when cfg info is unavailable (at least it gives an approximation), and it also
    // covers jitted code or any code missing from the cfg. There is no known upper bound on
    // the number of TBs, so no coverage percentage can be derived from it.
    std::string tbCoveragePath = outputPathFor("tbcoverage-");
    m_tbcoverage->generateJsonCoverageFile(state, tbCoveragePath);
    data.push_back(std::make_pair("tbcoverage_filename", QOBJECT(qstring_from_str(tbCoveragePath.c_str()))));

    data.push_back(std::make_pair("fault_address", QOBJECT(qint_from_int(opt.m_faultAddress))));

    data.push_back(std::make_pair("xml_testcase_filename", QOBJECT(qstring_from_str(xmlPovPath.c_str()))));
    data.push_back(std::make_pair("c_testcase_filename", QOBJECT(qstring_from_str(cPovPath.c_str()))));
    data.push_back(std::make_pair("pov_type", QOBJECT(qint_from_int(opt.m_type))));

    data.push_back(std::make_pair("state_id", QOBJECT(qint_from_int(state->getID()))));

    // The seed from which this test case was derived.
    const auto seedIndex = m_seedSearcher->getSubtreeSeedIndex(state);
    data.push_back(std::make_pair("seed_id", QOBJECT(qint_from_int(seedIndex))));

    Events::emitQMPEvent(this, data);

    ++testCaseIndex;
}

/// Picks the XML and C renderings of a test case out of the files produced by the PoV generator.
///
/// Paths are classified by their extension, ".xml" or ".c". If several paths share an
/// extension, the last one wins. Output parameters are only written for matching paths.
///
/// @return true iff both an XML and a C file were found.
static bool GetXmlCFiles(const std::vector<std::string> &filePaths, std::string &xmlFilePath, std::string &cFilePath) {
    // Match on the suffix rather than anywhere in the path. Otherwise a file in a
    // directory whose name contains ".c" (e.g. "~/.cache/pov.bin") would be taken for
    // the C file.
    auto hasExtension = [](const std::string &path, const std::string &ext) {
        return path.size() >= ext.size() && path.compare(path.size() - ext.size(), ext.size(), ext) == 0;
    };

    bool foundXml = false;
    bool foundC = false;

    for (const auto &fp : filePaths) {
        if (hasExtension(fp, ".xml")) {
            xmlFilePath = fp;
            foundXml = true;
        } else if (hasExtension(fp, ".c")) {
            cFilePath = fp;
            foundC = true;
        }
    }

    return foundXml && foundC;
}

/// Handler for PovGenerationPolicy::onPovReady (also invoked by sendCoveragePov).
/// Locates the XML and C test case files among @p filePaths and forwards them to
/// sendTestcase(); logs a warning and drops the test case if either is missing.
void CGCInterface::onPovReady(S2EExecutionState *state, const PovOptions &opt, const std::string &recipeName,
                              const std::vector<std::string> &filePaths, TestCaseType tcType) {
    std::string xmlPath;
    std::string cPath;

    const bool found = GetXmlCFiles(filePaths, xmlPath, cPath);
    if (found) {
        sendTestcase(state, xmlPath, cPath, tcType, opt, recipeName);
    } else {
        getWarningsStream(state) << "Could not find xml/c files in the generated files\n";
    }
}

/// Merges the translation block coverage of @p state into the global coverage bitmap
/// (m_coveredTbs) and into this instance's coverage (m_localCoveredTbs).
///
/// @return true if @p state covers blocks that were not covered before. The global bitmap
///         decides this for every TB it could record. If recording failed for some TB
///         (setCovered() returned false), the result falls back to whether the per-instance
///         coverage grew.
bool CGCInterface::updateCoverage(S2EExecutionState *state) {
    bool hasNewCoveredBlocks = false;
    bool success = true;
    auto bmp = m_coveredTbs.acquire();

    // Bind by reference: the coverage map can be large and is reused for the local merge below.
    const auto &tbcoverage = m_tbcoverage->getCoverage(state);

    for (const auto &it : tbcoverage) {
        const auto &module = it.first;
        const auto &tbs = it.second;

        ModuleDescriptor desc;
        unsigned index = 0;
        desc.Name = module;
        if (!m_detector->getModuleId(desc, &index)) {
            continue;
        }

        for (const auto &tb : tbs) {
            bool covered = false;
            if (!bmp->setCovered(index, tb.startOffset, tb.size, covered)) {
                // The global answer is unknown for this TB. Do not count it as new here;
                // the per-instance fallback below decides instead.
                success = false;
                continue;
            }
            hasNewCoveredBlocks |= !covered;
        }
    }

    m_coveredTbs.release();

    // In case global coverage could not be determined, fallback
    // to per-instance coverage.
    bool localCoverageGrew = coverage::mergeCoverage(m_localCoveredTbs, tbcoverage);
    if (!success) {
        hasNewCoveredBlocks |= localCoverageGrew;
    }

    return hasNewCoveredBlocks;
}

/// Once every m_coverageTimeout seconds, reports a PARTIAL_PATH test case for each state
/// that TranslationBlockCoverage flagged as having new blocks and that updateCoverage()
/// confirms as new. Coverage is normally reported when a path completes; this catches
/// states that have not finished yet.
///
/// @param currentTime wall-clock time in seconds, as computed by onTimer().
void CGCInterface::processIntermediateCoverage(uint64_t currentTime) {
    const bool timeoutElapsed = currentTime - m_timeOfLastCoverageReport >= m_coverageTimeout;
    if (!timeoutElapsed) {
        return;
    }

    getDebugStream() << "Looking for states with new covered blocks...\n";

    auto candidates = m_tbcoverage->getStatesWithNewBlocks();
    for (auto candidate : candidates) {
        auto *state = dynamic_cast<S2EExecutionState *>(candidate);
        if (!updateCoverage(state)) {
            continue;
        }

        getDebugStream(state) << "Reporting new blocks\n";
        sendCoveragePov(state, PovGenerationPolicy::PARTIAL_PATH);
    }

    m_tbcoverage->clearStatesWithNewBlocks();
    m_timeOfLastCoverageReport = currentTime;
}

/// Generates a PoV for @p state and reports it as a coverage test case via onPovReady().
///
/// @param tctype END_OF_PATH (generator prefix "kill") or PARTIAL_PATH (prefix "partial");
///               any other type is rejected with a warning.
/// @return false if @p tctype is not a coverage type or PoV generation failed, true otherwise.
bool CGCInterface::sendCoveragePov(S2EExecutionState *state, TestCaseType tctype) {
    std::string prefix;
    switch (tctype) {
        case PovGenerationPolicy::END_OF_PATH:
            prefix = "kill";
            break;
        case PovGenerationPolicy::PARTIAL_PATH:
            prefix = "partial";
            break;
        default:
            getWarningsStream(state) << "Invalid coverage tc type\n";
            return false;
    }

    pov::PovOptions povOptions;
    std::vector<std::string> generatedFiles;
    const bool generated = m_povGenerator->generatePoV(state, povOptions, prefix, generatedFiles);
    if (!generated) {
        getWarningsStream(state) << "Failed to generate PoV\n";
        return false;
    }

    onPovReady(state, povOptions, "", generatedFiles, tctype);
    return true;
}

/// Handler for CorePlugin::onStateKill. Reports the dying state as an END_OF_PATH test case,
/// but only when updateCoverage() says it covered new blocks.
void CGCInterface::onStateKill(S2EExecutionState *state) {
    getInfoStream(state) << "State was killed, generating testcase\n";

    // TODO: share coverage info between nodes
    if (updateCoverage(state)) {
        sendCoveragePov(state, PovGenerationPolicy::END_OF_PATH);
    }
}

} // namespace plugins
} // namespace s2e
